#----------------------------------------------------------------------
#  MDLSM method test - 2d rod thermal analysis
#  Author: Andrea Pavan
#  Date: 28/02/2023
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using PyPlot;
include("utils.jl");


#--- problem definition -----------------------------------------------
#rectangular rod occupying [0,l1] x [0,l2]
l1, l2 = 10.0, 1.0;             #domain x and y size
k = 0.1;                        #thermal conductivity
Tleft, Tright = 20, 100;        #imposed temperatures on the left and right edges

#--- discretization parameters ----------------------------------------
meshSize = 0.2;                 #spacing of the internal node grid (finer alternative: 0.025)
surfaceMeshSize = meshSize;     #spacing of the boundary nodes
minNeighbors = 5;               #minimum number of neighbors required for every node
minSearchRadius = meshSize;     #initial (and increment base of the) neighbor search half-width


#pointcloud generation
#Boundary nodes are created first (so they occupy columns 1:length(boundaryNodes)),
#then internal nodes are placed on a randomly jittered grid.
time1 = time();
pointcloud = ElasticArray{Float64}(undef,2,0);      #2xN matrix containing the coordinates [X;Y] of each node
normals = ElasticArray{Float64}(undef,2,0);     #2xN matrix containing the components [nx;ny] of the normal of each boundary node
#left and right edges (corner points are excluded by the range limits)
for i=0+surfaceMeshSize:surfaceMeshSize:l2-surfaceMeshSize
    append!(pointcloud, [0,i]);
    append!(normals, [-1,0]);
    append!(pointcloud, [l1,i]);
    append!(normals, [1,0]);
end
#bottom and top edges (corner points are excluded by the range limits)
for i=0+surfaceMeshSize:surfaceMeshSize:l1-surfaceMeshSize
#for i=0:surfaceMeshSize:l1
    append!(pointcloud, [i,0]);
    append!(normals, [0,-1]);
    append!(pointcloud, [i,l2]);
    append!(normals, [0,1]);
end
boundaryNodes = collect(1:size(pointcloud,2));      #indices of the boundary nodes
#internal nodes: grid points jittered by up to +-meshSize/10 in each direction;
#a candidate is discarded if it lies closer than 0.75*meshSize to any boundary node
for x=0+meshSize:meshSize:l1-meshSize
    for y=0+meshSize:meshSize:l2-meshSize
        newP = [x,y]+(rand(Float64,2).-0.5).*meshSize/5;
        tooClose = any(j -> (newP[1]-pointcloud[1,j])^2+(newP[2]-pointcloud[2,j])^2<(0.75*meshSize)^2, boundaryNodes);
        if !tooClose
            append!(pointcloud, newP);
            append!(normals, [0,0]);        #internal nodes carry a null normal
        end
    end
end
internalNodes = collect(1+length(boundaryNodes):size(pointcloud,2));
println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));

#pointcloud plot
#=figure();
plot(pointcloud[1,boundaryNodes],pointcloud[2,boundaryNodes],"r.");
plot(pointcloud[1,internalNodes],pointcloud[2,internalNodes],"k.");
title("Pointcloud plot");
axis("equal");
display(gcf());=#


#neighbor search
#Each node collects every other node lying inside an axis-aligned square of
#half-width searchradius centered on it; the square grows by minSearchRadius/2
#until at least minNeighbors nodes are found.
time2 = time();
N = size(pointcloud,2);     #number of nodes
#without this check the search below never terminates when minNeighbors > N-1,
#and leaves neighbors[i] undefined when minNeighbors < 1
1 <= minNeighbors <= N-1 || throw(ArgumentError("minNeighbors=$minNeighbors must be between 1 and N-1=$(N-1)"));
neighbors = Vector{Vector{Int}}(undef,N);       #vector containing N vectors of the indices of each node neighbors
Nneighbors = zeros(Int,N);      #number of neighbors of each node
for i=1:N
    searchradius = minSearchRadius;
    while Nneighbors[i]<minNeighbors
        neighbors[i] = Int[];
        #check every other node; component-wise test avoids allocating coordinate copies
        for j=1:N
            j==i && continue
            if abs(pointcloud[1,j]-pointcloud[1,i])<searchradius && abs(pointcloud[2,j]-pointcloud[2,i])<searchradius
                push!(neighbors[i],j);      #each j is visited once, so no duplicates arise
            end
        end
        Nneighbors[i] = length(neighbors[i]);
        searchradius += minSearchRadius/2;
    end
end
println("Found neighbors in ", round(time()-time2,digits=2), " s");
println("Connectivity properties:");
println("  Max neighbors: ",maximum(Nneighbors)," (at index ",argmax(Nneighbors),")");
println("  Avg neighbors: ",round(sum(Nneighbors)/length(Nneighbors),digits=2));
println("  Min neighbors: ",minimum(Nneighbors)," (at index ",argmin(Nneighbors),")");


#neighbors distances and weights
#For node i: column j of P[i] is the offset of its j-th neighbor from node i,
#r2[i][j] is the squared length of that offset, and
#w[i][j] = exp(-6*r2[i][j]/maximum(r2[i])) is the neighbor weight.
time3 = time();
P = Vector{Array{Float64}}(undef,N);        #relative positions of the neighbors
r2 = Vector{Vector{Float64}}(undef,N);      #squared relative distances of the neighbors
w = Vector{Vector{Float64}}(undef,N);       #neighbors weights
for i=1:N
    offsets = Array{Float64}(undef,2,Nneighbors[i]);
    for (j, nb) in enumerate(neighbors[i])
        offsets[:,j] = pointcloud[:,nb]-pointcloud[:,i];
    end
    dist2 = [offsets[:,j]'offsets[:,j] for j=1:Nneighbors[i]];
    P[i] = offsets;
    r2[i] = dist2;
    w[i] = exp.(-6 .* dist2 ./ maximum(dist2));
    #unit weights alternative: w[i] = ones(Nneighbors[i]);
end


#boundary conditions
#Coefficients of the boundary equation g1*T + g2*dT/dn = g3 for each boundary node:
#fixed temperature on the left (x=0) and right (x=l1) edges, zero normal derivative elsewhere.
g1 = zeros(Float64,N);
g2 = zeros(Float64,N);
g3 = zeros(Float64,N);
for node in boundaryNodes
    xnode = pointcloud[1,node];
    isLeft = xnode <= 1e-6;
    isRight = !isLeft && xnode >= l1-1e-6;
    if isLeft || isRight
        #left/right surface: imposed temperature
        g1[node], g2[node], g3[node] = 1.0, 0.0, (isLeft ? Tleft : Tright);
    else
        #top/bottom surface: insulated
        g1[node], g2[node], g3[node] = 0.0, 1.0, 0.0;
    end
end


#stencil coefficients
Cx = Vector{Vector{Float64}}(undef,N);       #neighbors stencil coefficients to approximate d/dx
Cy = Vector{Vector{Float64}}(undef,N);       #neighbors stencil coefficients to approximate d/dy

#alternative (disabled): closed-form weighted stencil from the 2x2 normal equations
#=
for i=1:N
    den = sum(w[i].*P[i][1,:].^2)*sum(w[i].*P[i][2,:].^2)-sum(w[i].*P[i][1,:].*P[i][2,:])^2;
    a1x = sum(w[i].*P[i][2,:].^2)/den;
    a1y = sum(w[i].*P[i][1,:].^2)/den;
    a2 = sum(w[i].*P[i][1,:].*P[i][2,:])/den;
    Cx[i] = w[i].*P[i][1,:].*a1x + w[i].*P[i][2,:].*a2;
    Cy[i] = w[i].*P[i][2,:].*a1y + w[i].*P[i][1,:].*a2;
end=#

#Active stencil: unweighted least-squares fit of a + b*x + c*y over the neighbor
#offsets of each node (the weights w are not used here). C[i] is the SVD-based
#pseudo-inverse of V, so its 2nd/3rd rows are the d/dx and d/dy coefficients.
C = Vector{Matrix{Float64}}(undef,N);        #stencil coefficients matrices
condC = zeros(N);       #stencil condition number
for i=1:N
    xj = P[i][1,:];
    yj = P[i][2,:];
    #one row [1 x y] per neighbor; the previous version allocated one extra all-zero
    #row (a leftover from the constrained variants), which only padded Cx/Cy with a 0
    V = hcat(ones(Nneighbors[i]), xj, yj);
    VF = svd(V);
    C[i] = transpose(VF.Vt)*inv(Diagonal(VF.S))*transpose(VF.U);
    #cond(C[i]) == cond(V) == largest/smallest singular value (VF.S is sorted decreasingly)
    condC[i] = VF.S[1]/VF.S[end];
    Cx[i] = C[i][2,:];
    Cy[i] = C[i][3,:];
end

#alternative (disabled): weighted quadratic stencil with an extra PDE row,
#plus a boundary-condition row on boundary nodes
#=C = Vector{Matrix}(undef,N);        #stencil coefficients matrices
for i in internalNodes
    xj = P[i][1,:];
    yj = P[i][2,:];
    V = zeros(Float64,1+Nneighbors[i],6);
    for j=1:Nneighbors[i]
        V[j,:] = [1, xj[j], yj[j], xj[j]^2, yj[j]^2, xj[j]*yj[j]];
    end
    V[1+Nneighbors[i],:] = [0, 0, 0, 2*k, 2*k, 0];
    W = Diagonal(vcat(w[i],2));
    (Q,R) = qr(W*V);
    C[i] = inv(R)*transpose(Matrix(Q))*W;
    Cx[i] = C[i][2,:];
    Cy[i] = C[i][3,:];
end
for i in boundaryNodes
    #println("Boundary node: ",i);
    xj = P[i][1,:];
    yj = P[i][2,:];
    V = zeros(Float64,2+Nneighbors[i],6);
    for j=1:Nneighbors[i]
        V[j,:] = [1, xj[j], yj[j], xj[j]^2, yj[j]^2, xj[j]*yj[j]];
    end
    V[1+Nneighbors[i],:] = [0, 0, 0, 2*k, 2*k, 0];
    V[2+Nneighbors[i],:] = [g1[i], g2[i]*normals[1,i], g2[i]*normals[2,i], 0, 0, 0];
    W = Diagonal(vcat(w[i],2,2));
    (Q,R) = qr(W*V);
    C[i] = inv(R)*transpose(Matrix(Q))*W;
    Cx[i] = C[i][2,:];
    Cy[i] = C[i][3,:];
end=#

println("Calculated stencil coefficients in ", round(time()-time3,digits=2), " s");
println("Stencil properties:");
println("  Max cond(C): ", trunc(Int,maximum(condC)));
println("  Avg cond(C): ", trunc(Int,sum(condC)/N));
println("  Min cond(C): ", trunc(Int,minimum(condC)));


#matrix assembly
#Unknowns are ordered as [T; qx; qy], each of length N. Equations (rows):
#  1:N        heat equation         sum_j Cx*qx_j + Cy*qy_j = 0
#  N+1:2N     qx definition         -k*sum_j Cx*T_j - qx_i = 0
#  2N+1:3N    qy definition         -k*sum_j Cy*T_j - qy_i = 0
#  3N+1:end   boundary condition    g1*T - g2*(n.q)/k = g3   (i.e. g1*T + g2*dT/dn = g3)
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    for j=1:Nneighbors[i]
        nb = neighbors[i][j];
        #heat equation
        push!(rows, i);
        push!(cols, N+nb);
        push!(vals, Cx[i][j]);
        push!(rows, i);
        push!(cols, 2*N+nb);
        push!(vals, Cy[i][j]);

        #qx definition (neighbor terms)
        push!(rows, N+i);
        push!(cols, nb);
        push!(vals, -k*Cx[i][j]);

        #qy definition (neighbor terms)
        push!(rows, 2*N+i);
        push!(cols, nb);
        push!(vals, -k*Cy[i][j]);
    end
    #qx/qy definitions (diagonal terms): pushed once per node, NOT once per neighbor,
    #because sparse() sums duplicate (row,col) entries and would otherwise yield -Nneighbors[i]
    push!(rows, N+i);
    push!(cols, N+i);
    push!(vals, -1);
    push!(rows, 2*N+i);
    push!(cols, 2*N+i);
    push!(vals, -1);
end
for (b, node) in enumerate(boundaryNodes)
    #boundary condition: row index is the position b in boundaryNodes,
    #column indices refer to the actual node index
    push!(rows, 3*N+b);
    push!(cols, node);
    push!(vals, g1[node]);
    push!(rows, 3*N+b);
    push!(cols, N+node);
    push!(vals, -g2[node]*normals[1,node]/k);
    push!(rows, 3*N+b);
    push!(cols, 2*N+node);
    push!(vals, -g2[node]*normals[2,node]/k);
end
L = sparse(rows,cols,vals,3*N+length(boundaryNodes),3*N);

wpde = 1.0;       #least squares weight for the pde
wbc = 2.0;        #least squares weight for the boundary condition
W2 = Diagonal(vcat(wpde*ones(3*N), wbc*ones(length(boundaryNodes))));
M = transpose(L)*W2*L;      #normal-equations matrix of the weighted least squares problem
println("Completed matrix assembly in ", round(time()-time4,digits=2), " s");


#linear system solution
#Weighted least squares through the normal equations M*sol = L'*W2*g.
#The PDE and flux-definition rows are homogeneous; the boundary rows take g3.
time5 = time();
g = vcat(zeros(3*N), g3[boundaryNodes]);       #rhs vector
b = transpose(L)*W2*g;
sol = qr(M)\b;
println("Linear system solved in ", round(time()-time5,digits=2), " s");


#solution plots
#split the solution vector [T; qx; qy] into its three fields
T = sol[1:N];
qx = sol[N+1:2*N];
qy = sol[2*N+1:3*N];

#=figure();
plot(pointcloud[1,:],T,"k.");
title("Temperature field");
display(gcf());=#

#one colored scatter plot per field
plotSpecs = ((T, "inferno", "Numerical solution - temperature"),
             (qx, "jet", "Numerical solution - qx heat flux"),
             (qy, "jet", "Numerical solution - qy heat flux"));
for (field, cmapName, titleText) in plotSpecs
    figure();
    scatter(pointcloud[1,:],pointcloud[2,:],c=field,cmap=cmapName);
    colorbar();
    title(titleText);
    axis("equal");
    display(gcf());
end
